from tkinter import E
import openai
import asyncio
from typing import List, Dict, Any
import argparse
from collections import defaultdict
import os 
from tqdm import tqdm
import re 
import time
import json 
import random 


# Command-line interface. Parsing happens at module level, so `args` is a
# module-wide namespace that the functions below receive explicitly.
parser = argparse.ArgumentParser("")
parser.add_argument("--temperature", default=1, type=float, help="sampling temperature")
parser.add_argument("--top_p", default=1.0, type=float, help="top_p for sampling")
parser.add_argument("--n_sample", default=10, type=int, help="number of examples to be generated")
parser.add_argument("--dataset", default='ncbi_disease', type=str, help="which dataset to use")

parser.add_argument("--model_name", default='gpt-3.5-turbo', type=str, help="which model to use")
parser.add_argument("--max_tokens", default=512, type=int, help="number of max tokens")
parser.add_argument("--output_dir", default='.', type=str, help="the folder for saving the generated text")
parser.add_argument("--keyword_type", default='.', type=str, help="kg or llm")

args = parser.parse_args()


# Take the key from the environment when it is set so that real credentials do
# not have to be written into the source; otherwise keep the placeholder.
api_key = os.environ.get("OPENAI_API_KEY", '12345')  # placeholder fallback
args.api_key = api_key

### Prompt Format ###
# Each supported dataset maps to a single entity type; the same name is used
# both as the key in the generated JSON and as the domain word in the prompt.
if args.dataset in ['ncbi_disease', 'bc5cdr_disease']:
    args.entity_list = ["Disease"]
    args.domain = 'Disease'
    args.n_label = 1
elif args.dataset in ['chemdner', 'bc5cdr_chemical']:
    args.entity_list = ["Chemical"]
    args.domain = 'Chemical'
    args.n_label = 1
else:
    raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")
#####################

async def dispatch_openai_requests(
    messages_list: List[List[Dict[str, Any]]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> List[str]:
    """Send one chat-completion request per conversation and await them all.

    Args:
        messages_list: One entry per request; each entry is the list of
            message dicts forwarded as ``messages``.
        model: Name of the OpenAI model to query.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens per request.
        top_p: Nucleus-sampling parameter forwarded to the API.
    Returns:
        The results of the ``acreate`` calls, gathered in the same order as
        ``messages_list``.
    """
    # Sampling settings are identical for every request in the batch.
    shared_options = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    pending = []
    for conversation in messages_list:
        pending.append(
            openai.ChatCompletion.acreate(messages=conversation, **shared_options)
        )
    results = await asyncio.gather(*pending)
    return results


def load_demo(args):
    '''
    Load the few-shot demonstrations used for synthetic data generation.

    Reads ``../data/{args.dataset}/train_few.jsonl`` (relative to the current
    working directory). Every non-blank line is a JSON object with a
    ``"tokens"`` list and a ``"spans"`` list of ``(start, end, entity_type)``
    triples, where ``end`` is an inclusive token index.

    Returns a list of dicts, one per example, holding the space-joined
    sentence under ``"Sentence"`` and, for every name in ``args.entity_list``,
    the list of mention strings of that type. Spans whose type is not in
    ``args.entity_list`` are ignored.
    '''
    class_example = []
    with open(f"../data/{args.dataset}/train_few.jsonl", 'r', encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            example = json.loads(line)
            tokens = example["tokens"]
            # Every requested entity type gets a key, even when it has no mention.
            keyword = {entity: [] for entity in args.entity_list}
            for (start, end, entity_type) in example["spans"]:
                if entity_type not in keyword:
                    # Annotation of a type this run does not generate.
                    continue
                keyword[entity_type].append(" ".join(tokens[start: (end + 1)]))
            example_dict = {"Sentence": " ".join(tokens)}
            example_dict.update(keyword)
            class_example.append(example_dict)
    return class_example


def load_keywords(args):
    '''
    Load the candidate entity keywords used to seed the generation prompts.

    Reads every regular file in ``../data/{args.dataset}/{args.keyword_type}/``
    (relative to the current working directory), one keyword per line. Each
    line is cleaned by removing surrounding whitespace, a leading run of ``-``,
    and surrounding quote/comma/bracket characters, then lower-cased. Empty
    results and keywords longer than 40 characters are dropped.

    Returns the cleaned keywords as a list, files visited in sorted name order.
    '''
    keywords = []
    keyword_dir = f"../data/{args.dataset}/{args.keyword_type}/"
    # os.listdir gives no ordering guarantee; sort so runs are reproducible.
    for file_name in sorted(os.listdir(keyword_dir)):
        path = os.path.join(keyword_dir, file_name)
        if not os.path.isfile(path):
            # Sub-directories cannot be opened as keyword files.
            continue
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                # Strip whitespace before punctuation; otherwise the space
                # after a list dash shields a leading quote from removal.
                text = line.strip().lstrip('-').strip()
                text = text.strip("\"\',()[]").strip().lower()
                if text == "" or len(text) > 40:
                    continue
                keywords.append(text)
    return keywords


def gen_one_prompt(args, keywords, few_shot_demo):
    """Build one generation prompt with randomly chosen style, topic and demo order.

    Args:
        args: Namespace providing ``domain`` (the entity type named in the prompt).
        keywords: Candidate entity mentions; one is picked as the topic the
            sentence must mention. May be empty.
        few_shot_demo: Demonstration dicts appended to the prompt as JSON, in
            a random order. The caller's list is left untouched.
    Returns:
        A ``(prompt, topic, style)`` tuple. ``topic`` is the empty string when
        the prompt does not require a specific keyword.
    """
    styles = ['biomedical articles', 'Pubmed articles']
    style = random.choice(styles)
    # Underscores are turned into spaces in the task description.
    prompt_init = re.sub(
        "_",
        " ",
        f"Suppose you need to create a dataset for {args.domain} recognition. "
        f"Your task is to:\n1. generate a sentence about {args.domain},"
        f"\n2. output a list of named entity about {args.domain} only.",
    ).strip()

    # Most prompts are tied to a keyword; a small fraction (or all of them when
    # no keyword is available) only constrain the writing style.
    if keywords and random.random() > 0.06:
        topic_i = random.choice(keywords)
        prompt_init += f"\n3. the sentence should mimic the style of {style},\n4. the sentence should mention the {args.domain} named '{topic_i}'.\n"
    else:
        topic_i = ''
        # Append (not assign): the task description above must stay in the prompt.
        prompt_init += f"\n3. the sentence should be written mimicing the style of {style}.\n"

    # Shuffle a copy so the demonstration order varies between prompts without
    # mutating the list owned by the caller.
    shuffled_demo = list(few_shot_demo)
    random.shuffle(shuffled_demo)
    demo_parts = [" Some examples are: \n"]
    for data in shuffled_demo:
        demo_parts.append("--------------------\n")
        demo_parts.append(f"{json.dumps(data)}\n")
    demo_parts.append("--------------------\n")
    prompt = prompt_init + "".join(demo_parts)
    return prompt, topic_i, style


def _parse_generation(text, entity_list):
    """Extract the well-formed examples from one raw model completion.

    The completion is split on the dashed separator used in the prompt; each
    non-empty fragment is expected to be a JSON object with a ``"Sentence"``
    key and one list per name in ``entity_list``. Entities that do not occur
    in the sentence (case-insensitively) are dropped. Fragments that cannot be
    decoded or lack the expected structure are reported and skipped.
    """
    examples = []
    for fragment in text.split("--------------------"):
        fragment = fragment.strip()
        if not fragment:
            # Text before the first / after the last separator is usually empty.
            continue
        try:
            text_json = json.loads(fragment)
            sentence = text_json["Sentence"]
            text_example = {"Sentence": sentence}
            for entity_name in entity_list:
                text_example[entity_name] = [
                    entity
                    for entity in text_json[entity_name]
                    if entity in sentence or entity.lower() in sentence.lower()
                ]
        except (ValueError, KeyError, TypeError, AttributeError):
            # ValueError covers json.JSONDecodeError; the others cover output
            # that is valid JSON but not shaped like an example.
            print("Decode Error!", fragment)
            continue
        examples.append(text_example)
    return examples


def main(args):
    """Generate synthetic examples in batches until ``args.n_sample`` is reached.

    Each batch sends 15 prompts to the model and writes the examples parsed
    from the completions to
    ``{args.output_dir}/{args.domain}/train_p{args.top_p}_{j}.jsonl``, where
    ``j`` is the batch index. Progress is printed after every batch.
    """
    augment_entities = load_keywords(args)
    print(len(augment_entities))
    print(augment_entities[:10])

    openai.api_key = args.api_key

    few_shot_demo = load_demo(args)
    os.makedirs(f"{args.output_dir}/{args.domain}/", exist_ok=True)
    batch_size = 15
    # Chat models return the text under message.content, others under text.
    is_chat_model = 'gpt-3.5' in args.model_name or 'gpt-4' in args.model_name

    example_cnt = 0
    j = 0
    while example_cnt < args.n_sample:
        prefix = f"{args.domain}/train_p{args.top_p}_{j}.jsonl"
        prompts = []
        keywords = []
        styles = []
        for _ in range(batch_size):
            prompt, keyword, style = gen_one_prompt(args, augment_entities, few_shot_demo)
            prompts.append([{"role": "user", "content": prompt}])
            keywords.append(keyword)
            styles.append(style)

        response = asyncio.run(
            dispatch_openai_requests(
                messages_list=prompts,
                model=args.model_name,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
                top_p=args.top_p,
            )
        )
        # parse the output from LLM
        if is_chat_model:
            ans = [x['choices'][0]['message']['content'] for x in response]
        else:
            ans = [x['choices'][0]['text'] for x in response]
        print(len(ans), len(styles), len(keywords))

        # The context manager closes (and flushes) each batch file; previously
        # a new handle was opened per batch and never closed.
        with open(f"{args.output_dir}/{prefix}", 'w', encoding="utf-8") as f:
            for text in ans:
                for text_example in _parse_generation(text, args.entity_list):
                    f.write(json.dumps(text_example) + '\n')
                    example_cnt += 1
        print("=========================")
        print(f"# Examples / Total: {example_cnt} / {args.n_sample}")
        j += 1
    

# Script entry point: run the generation loop with the module-level `args`
# namespace produced by `parser.parse_args()` above.
if __name__ == '__main__':
    main(args)

